Setup

# memory clean up
# NOTE(review): rm(list = ls()) only removes the objects in the global
# environment; attached packages and options() are left untouched, so this is
# not equivalent to starting from a fresh R session.
rm(list = ls()) # remove all objects
gc() # garbage collection
##          used (Mb) gc trigger (Mb) limit (Mb) max used (Mb)
## Ncells 524651 28.1    1164288 62.2         NA   669428 35.8
## Vcells 968910  7.4    8388608 64.0      16384  1851873 14.2
# Import libraries

library(tidyverse)
library(corrr)
library(knitr)
library(kableExtra)
library(GGally)
library(tidymodels)
library(rsample)
library(ggplot2)
# Project root. Override it without editing this file by setting the
# METLIFE_PROJECT_DIR environment variable; the default keeps the original
# location.
project_dir <- Sys.getenv(
  "METLIFE_PROJECT_DIR",
  unset = "/Users/miguelkiszkurno/Documents/metlife/test-docker"
)
knitr::opts_knit$set(root.dir = file.path(project_dir, "src"))

# Seed reused wherever reproducible randomness is needed
g_seed <- 730

# Load the dataset: comma-separated, "." as decimal mark, header row.
# fill = TRUE lets rows with fewer fields be read by adding blank fields.
ds <- read.table(
  file.path(project_dir, "datasets", "dataset.csv"),
  sep = ",", dec = ".", header = TRUE, fill = TRUE
)

# Fix the seed so the train/test split is reproducible
set.seed(g_seed)

# Split the data: 90% training, 10% test
train_test <- initial_split(ds, prop = 0.9)
ds <- training(train_test)
ds_test <- testing(train_test)

# Check dimensions
ds %>%
  dim_desc() 
## [1] "[1,204 x 7]"
ds_test %>%
  dim_desc() 
## [1] "[134 x 7]"
#General view of the dataset
glimpse(ds)
## Rows: 1,204
## Columns: 7
## $ age      <int> 43, 43, 63, 43, 36, 52, 56, 18, 64, 50, 48, 44, 41, 49, 30, 4…
## $ sex      <chr> "female", "male", "male", "male", "male", "male", "female", "…
## $ bmi      <dbl> 20.045, 38.060, 28.310, 35.310, 29.700, 41.800, 27.200, 23.08…
## $ children <int> 2, 2, 0, 2, 0, 2, 0, 0, 0, 3, 1, 1, 3, 0, 3, 2, 0, 0, 0, 0, 1…
## $ smoker   <chr> "yes", "yes", "no", "no", "no", "yes", "no", "no", "yes", "no…
## $ region   <chr> "northeast", "southeast", "northwest", "southeast", "southeas…
## $ charges  <dbl> 19798.055, 42560.430, 13770.098, 18806.145, 4399.731, 47269.8…

Preliminary analysis

# Per-variable overview: number of distinct values and percentage of missing
# values. The columns have mixed types, so every value is converted to
# character before being stacked into long format.
exp_table <- ds %>%
  pivot_longer(
    cols = everything(),
    names_to = "variables",
    values_to = "values",
    values_transform = list(values = as.character)
  ) %>%
  group_by(variables) %>%
  summarise(
    unique_values = n_distinct(values),
    # n() is the number of rows of the variable, i.e. the rows of the dataset
    percentaje_missing = sum(is.na(values)) / n() * 100
  ) %>%
  # variables with most missing values first, then fewest distinct values
  arrange(desc(percentaje_missing), unique_values)

# View the table
exp_table
# List the distinct values found in each categorical field, one small
# (campo, valor) table per field
regions <- data.frame(campo = "region", valor = unique(ds[["region"]]))
smokers <- data.frame(campo = "smoker", valor = unique(ds[["smoker"]]))
sex <- data.frame(campo = "sex", valor = unique(ds[["sex"]]))

# Stack the three tables so everything can be read together
df_unique_values <- rbind(regions, smokers, sex)

# Show the stacked table
df_unique_values
#' Recode the categorical columns of the dataset.
#'
#' `region` and `sex` are converted to factors and `smoker` becomes a logical
#' flag that is `TRUE` where the original value equals "yes".
#'
#' @param data_set A data frame with `region`, `sex` and `smoker` columns.
#' @return `data_set` with those three columns recoded.
transform_factor_variables <- function(data_set) {
  for (column in c("region", "sex")) {
    data_set[[column]] <- factor(data_set[[column]])
  }
  data_set[["smoker"]] <- data_set[["smoker"]] == "yes"

  data_set
}

ds <- transform_factor_variables(ds)
ds_test <-transform_factor_variables (ds_test)

# Another quick look to the resulting dataset
glimpse(ds)
## Rows: 1,204
## Columns: 7
## $ age      <int> 43, 43, 63, 43, 36, 52, 56, 18, 64, 50, 48, 44, 41, 49, 30, 4…
## $ sex      <fct> female, male, male, male, male, male, female, male, female, m…
## $ bmi      <dbl> 20.045, 38.060, 28.310, 35.310, 29.700, 41.800, 27.200, 23.08…
## $ children <int> 2, 2, 0, 2, 0, 2, 0, 0, 0, 3, 1, 1, 3, 0, 3, 2, 0, 0, 0, 0, 1…
## $ smoker   <lgl> TRUE, TRUE, FALSE, FALSE, FALSE, TRUE, FALSE, FALSE, TRUE, FA…
## $ region   <fct> northeast, southeast, northwest, southeast, southeast, southe…
## $ charges  <dbl> 19798.055, 42560.430, 13770.098, 18806.145, 4399.731, 47269.8…
# Standard descriptive statistics for every numeric column of the dataset
ds_metrics <- ds %>%
  select(where(is.numeric)) %>%
  pivot_longer(
    cols = everything(),
    names_to = "variable",
    values_to = "values"
  ) %>%
  group_by(variable) %>% # one group per numeric column
  summarise(
    min = min(values),
    mean = mean(values),
    median = median(values),
    sd = sd(values),
    max = max(values)
  ) %>%
  arrange(variable) # sort by variable name
ds_metrics
# Look at the relationships between the variables: one pairs plot of the
# numeric columns per categorical variable, coloured by that variable.
# `numericas` is overwritten for each plot.
numericas <- ds %>%
  select(where(is.numeric), sex)

ggpairs(numericas,  mapping = aes(color = sex))

numericas <- ds %>%
  select(where(is.numeric), smoker)

ggpairs(numericas,  mapping = aes(color = smoker))

numericas <- ds %>%
  select(where(is.numeric), region)
ggpairs(numericas,  mapping = aes(color = region))

# Pairwise correlations between the numeric columns. The non-numeric columns
# (sex, smoker, region) are dropped explicitly instead of relying on how
# correlate() treats them, and the correlation data frame is computed once
# and reused for both the table and the plot.
ds_correlations <- ds %>%
  select(where(is.numeric)) %>%
  correlate()

# Correlation table with the upper triangle removed, formatted for printing
ds_correlations %>%
  shave() %>%
  fashion()

# Correlation plot
ds_correlations %>%
  rplot()

## $ age      19, 18, 28, 33, 32, 31, 46, 37, 37, 60, 25, 62, 23, 56, 27, 19, 52, 23, 56, 30, 60, 30, 18, 34, 37, 59, 63, 55, 23, 31, 22, 18, 19, 63, 28, 19, 62, 26, …
## $ bmi      27.900, 33.770, 33.000, 22.705, 28.880, 25.740, 33.440, 27.740, 29.830, 25.840, 26.220, 26.290, 34.400, 39.820, 42.130, 24.600, 30.780, 23.845, 40.300, …
## $ children 0, 1, 3, 0, 0, 0, 1, 3, 2, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 2, 3, 0, 2, 1, 2, 0, 0, 5, 0, 1, 0, 3, 0, 1, 0, 0, 2, 1, 2, 1, 0, 2, 0, 0, 1, 0,…

# Charges by sex
ggplot(ds, aes(x = sex, y = charges, group = sex, fill = sex))+
  geom_boxplot() +
  scale_y_continuous(limits = c(0, 70000))

# Charges by region
ggplot(ds, aes(x = region, y = charges, group = region, fill = region))+
  geom_boxplot() +
  scale_y_continuous(limits = c(0, 70000)) # set the y-axis scale

# Charges by smoker
ggplot(ds, aes(x = smoker, y = charges, group = smoker, fill = smoker))+
  geom_boxplot() +
  scale_y_continuous(limits = c(0, 70000)) # set the y-axis scale

# Upper outlier fence for charges: third quartile + 1.5 * IQR
limite_superior_outliers = IQR(ds$charges) * 1.5 + quantile(ds$charges, 0.75)[[1]]
limite_superior_outliers
## [1] 34699.56
# Rows whose charges exceed the upper fence, largest charges first
outliers_charges <- ds %>% filter(charges>limite_superior_outliers) %>% arrange(., desc(charges))
outliers_charges
# Lower outlier fence for charges: first quartile - 1.5 * IQR
limite_inferior_charges = quantile(ds$charges, 0.25)[[1]] - IQR(ds$charges) * 1.5
limite_inferior_charges 
## [1] -13082.39

Linear Regression: Predict charges using all the features

Model Fitting

# Fit the model. The predictors are listed explicitly (the same set used by
# the log models) instead of `.`, so that helper columns added to ds later on,
# such as log.charges, cannot be picked up as predictors if this chunk is
# re-run.
mdl_all <- lm(
  formula = charges ~ age + sex + bmi + children + smoker + region,
  data = ds
)

# View the model summary (dataframe format)
tidy_mdl_all <- tidy(mdl_all, conf.int = TRUE) %>% arrange(p.value)
tidy_mdl_all
# View the model summary (plain text format)
summary_mdl_all <- summary(mdl_all)
summary_mdl_all
## 
## Call:
## lm(formula = charges ~ ., data = ds)
## 
## Residuals:
##    Min     1Q Median     3Q    Max 
## -11267  -2921  -1031   1392  29943 
## 
## Coefficients:
##                  Estimate Std. Error t value Pr(>|t|)    
## (Intercept)     -11896.08    1042.20 -11.414  < 2e-16 ***
## age                261.26      12.71  20.555  < 2e-16 ***
## sexmale           -211.89     354.00  -0.599  0.54958    
## bmi                334.52      30.02  11.142  < 2e-16 ***
## children           436.03     144.77   3.012  0.00265 ** 
## smokerTRUE       23790.90     440.36  54.026  < 2e-16 ***
## regionnorthwest   -133.25     508.87  -0.262  0.79348    
## regionsoutheast   -945.58     506.10  -1.868  0.06196 .  
## regionsouthwest   -979.75     502.03  -1.952  0.05122 .  
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 6114 on 1195 degrees of freedom
## Multiple R-squared:  0.7455, Adjusted R-squared:  0.7438 
## F-statistic: 437.5 on 8 and 1195 DF,  p-value: < 2.2e-16

Model diagnosis

# Plot the residuals
plot(mdl_all)

Linear Regression: Predict log(charges) using all the features

Previous calculations

# Add the log of charges as a new column; it is the target of the log models
ds <- mutate(ds, log.charges = log(charges))

head(ds)
# Charges (without applying log)
# after_stat(count) replaces the deprecated ..count.. notation; bins = 30 is
# the value stat_bin() was already defaulting to, now stated explicitly.
ggplot(data = ds, aes(x = round(charges))) +
  geom_histogram(
    aes(fill = after_stat(count)),
    bins = 30, col = "white", alpha = 0.75
  ) +
  labs(title = "charges Histogram", x = "charges") +
  theme_bw()
## Warning: The dot-dot notation (`..count..`) was deprecated in ggplot2 3.4.0.
## ℹ Please use `after_stat(count)` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

# Charges (applying log)
# after_stat(count) replaces the deprecated ..count.. notation; bins = 30 is
# the value stat_bin() was already defaulting to, now stated explicitly.
ggplot(data = ds, aes(x = log.charges)) +
  geom_histogram(
    aes(fill = after_stat(count)),
    bins = 30, col = "white", alpha = 0.75
  ) +
  labs(title = "log.charges Histogram", x = "log.charges") +
  theme_bw()
## `stat_bin()` using `bins = 30`. Pick better value with `binwidth`.

Model fitting

#Fit the model using log.charges as target and all the other features as predictors
mdl_log_all = lm(formula = log.charges ~ age + sex + bmi + children + smoker + region, data = ds)
mdl_log_all
## 
## Call:
## lm(formula = log.charges ~ age + sex + bmi + children + smoker + 
##     region, data = ds)
## 
## Coefficients:
##     (Intercept)              age          sexmale              bmi  
##         7.04796          0.03460         -0.07586          0.01288  
##        children       smokerTRUE  regionnorthwest  regionsoutheast  
##         0.09773          1.54577         -0.03791         -0.14564  
## regionsouthwest  
##        -0.12260
# View the model summary (plain text format)
summary_mdl_log_all = summary(mdl_log_all)
summary_mdl_log_all
## 
## Call:
## lm(formula = log.charges ~ age + sex + bmi + children + smoker + 
##     region, data = ds)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -1.07451 -0.19320 -0.05250  0.05643  2.11358 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)    
## (Intercept)      7.0479599  0.0750908  93.859  < 2e-16 ***
## age              0.0345976  0.0009158  37.779  < 2e-16 ***
## sexmale         -0.0758583  0.0255059  -2.974 0.002997 ** 
## bmi              0.0128824  0.0021632   5.955 3.41e-09 ***
## children         0.0977299  0.0104304   9.370  < 2e-16 ***
## smokerTRUE       1.5457667  0.0317279  48.719  < 2e-16 ***
## regionnorthwest -0.0379093  0.0366642  -1.034 0.301365    
## regionsoutheast -0.1456352  0.0364645  -3.994 6.90e-05 ***
## regionsouthwest -0.1225950  0.0361715  -3.389 0.000724 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.4405 on 1195 degrees of freedom
## Multiple R-squared:  0.7656, Adjusted R-squared:  0.764 
## F-statistic: 487.9 on 8 and 1195 DF,  p-value: < 2.2e-16
# View the model summary (dataframe format)
tidy_summary_mdl_log_all <- tidy(summary_mdl_log_all, conf.int = TRUE) %>% arrange(p.value)
tidy_summary_mdl_log_all
# Coefficients plot: point estimate and confidence interval of every term,
# with a dashed reference line at zero.
ggplot(
  tidy_summary_mdl_log_all,
  aes(estimate, term, xmin = conf.low, xmax = conf.high, height = 0)
) +
  geom_point(color = "forestgreen", size = 2) +
  geom_vline(xintercept = 0, lty = 4, color = "black") +
  # linewidth replaces the deprecated use of size for line geoms
  geom_errorbarh(color = "forestgreen", linewidth = 1) +
  theme_bw() +
  labs(y = "Coeficientes β", x = "Estimación")
## Warning: Using `size` aesthetic for lines was deprecated in ggplot2 3.4.0.
## ℹ Please use `linewidth` instead.
## This warning is displayed once every 8 hours.
## Call `lifecycle::last_lifecycle_warnings()` to see where this warning was
## generated.

Model diagnosis

# Plot the residuals
plot(mdl_log_all)

# Linear Regression: Predict log(charges) adding interaction between age and smoker (mdl_log_all_it_smoker_age)

Model Fitting

# Fit the model
mdl_log_all_it_smoker_age = lm(formula = log.charges ~ sex + bmi + children + smoker * age + region, data = ds)
mdl_log_all_it_smoker_age
## 
## Call:
## lm(formula = log.charges ~ sex + bmi + children + smoker * age + 
##     region, data = ds)
## 
## Coefficients:
##     (Intercept)          sexmale              bmi         children  
##         6.82522         -0.07461          0.01233          0.09998  
##      smokerTRUE              age  regionnorthwest  regionsoutheast  
##         2.71788          0.04059         -0.04093         -0.13390  
## regionsouthwest   smokerTRUE:age  
##        -0.13592         -0.03035
# View the model summary (plain text format)
summary_mdl_log_all_it_smoker_age = summary(mdl_log_all_it_smoker_age)
summary_mdl_log_all_it_smoker_age
## 
## Call:
## lm(formula = log.charges ~ sex + bmi + children + smoker * age + 
##     region, data = ds)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.94319 -0.19701 -0.07306  0.03123  2.25290 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)    
## (Intercept)      6.8252241  0.0710901  96.008  < 2e-16 ***
## sexmale         -0.0746083  0.0235648  -3.166  0.00158 ** 
## bmi              0.0123298  0.0019989   6.168 9.44e-10 ***
## children         0.0999821  0.0096378  10.374  < 2e-16 ***
## smokerTRUE       2.7178782  0.0867664  31.324  < 2e-16 ***
## age              0.0405865  0.0009434  43.022  < 2e-16 ***
## regionnorthwest -0.0409261  0.0338744  -1.208  0.22722    
## regionsoutheast -0.1339012  0.0336991  -3.973 7.51e-05 ***
## regionsouthwest -0.1359236  0.0334313  -4.066 5.10e-05 ***
## smokerTRUE:age  -0.0303543  0.0021149 -14.353  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.407 on 1194 degrees of freedom
## Multiple R-squared:  0.8001, Adjusted R-squared:  0.7986 
## F-statistic:   531 on 9 and 1194 DF,  p-value: < 2.2e-16
# View the model summary (dataframe format)
tidy_summary_mdl_log_all_it_smoker_age <- tidy(summary_mdl_log_all_it_smoker_age, conf.int = TRUE) %>% arrange(p.value)
tidy_summary_mdl_log_all_it_smoker_age

Model diagnosis

# Plot the residuals
plot(mdl_log_all_it_smoker_age)

Linear Regression: Predict log(charges) adding interaction between smoker and bmi (mdl_log_all_it_smoker_age_bmi)

# Fit the model
mdl_log_all_it_smoker_age_bmi = lm(formula = log.charges ~ sex + smoker * bmi + children + smoker * age + region, data = ds)
mdl_log_all_it_smoker_age_bmi
## 
## Call:
## lm(formula = log.charges ~ sex + smoker * bmi + children + smoker * 
##     age + region, data = ds)
## 
## Coefficients:
##     (Intercept)          sexmale       smokerTRUE              bmi  
##        7.145613        -0.084892         1.287737         0.001423  
##        children              age  regionnorthwest  regionsoutheast  
##        0.101371         0.041184        -0.051525        -0.140367  
## regionsouthwest   smokerTRUE:bmi   smokerTRUE:age  
##       -0.145729         0.049605        -0.032662
# View the model (plain text format)
summary_mdl_log_all_it_smoker_age_bmi = summary(mdl_log_all_it_smoker_age_bmi)
summary_mdl_log_all_it_smoker_age_bmi
## 
## Call:
## lm(formula = log.charges ~ sex + smoker * bmi + children + smoker * 
##     age + region, data = ds)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -0.67814 -0.15413 -0.06715 -0.01521  2.30224 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)    
## (Intercept)      7.1456131  0.0733937  97.360  < 2e-16 ***
## sexmale         -0.0848919  0.0224366  -3.784 0.000162 ***
## smokerTRUE       1.2877371  0.1516852   8.490  < 2e-16 ***
## bmi              0.0014235  0.0021350   0.667 0.505060    
## children         0.1013706  0.0091696  11.055  < 2e-16 ***
## age              0.0411841  0.0008991  45.808  < 2e-16 ***
## regionnorthwest -0.0515253  0.0322395  -1.598 0.110262    
## regionsoutheast -0.1403670  0.0320642  -4.378 1.30e-05 ***
## regionsouthwest -0.1457293  0.0318162  -4.580 5.13e-06 ***
## smokerTRUE:bmi   0.0496047  0.0044140  11.238  < 2e-16 ***
## smokerTRUE:age  -0.0326617  0.0020224 -16.150  < 2e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.3872 on 1193 degrees of freedom
## Multiple R-squared:  0.8192, Adjusted R-squared:  0.8177 
## F-statistic: 540.7 on 10 and 1193 DF,  p-value: < 2.2e-16
# View the model (dataframe format)
tidy_summary_mdl_log_all_it_smoker_age_bmi <- tidy(summary_mdl_log_all_it_smoker_age_bmi, conf.int = TRUE) %>% arrange(p.value)
tidy_summary_mdl_log_all_it_smoker_age_bmi

Model diagnosis

plot(mdl_log_all_it_smoker_age_bmi)

Linear Regression: Predict log(charges) adding interaction between age and bmi (mdl_log_all_it_age_bmi)

mdl_log_all_it_age_bmi = lm(formula = log.charges ~ sex + age * bmi + children + smoker + region, data = ds)
mdl_log_all_it_age_bmi
## 
## Call:
## lm(formula = log.charges ~ sex + age * bmi + children + smoker + 
##     region, data = ds)
## 
## Coefficients:
##     (Intercept)          sexmale              age              bmi  
##       6.8100214       -0.0755310        0.0407891        0.0206788  
##        children       smokerTRUE  regionnorthwest  regionsoutheast  
##       0.0977903        1.5453779       -0.0383864       -0.1469699  
## regionsouthwest          age:bmi  
##      -0.1213469       -0.0002011
# View the model summary (plain text format)
summary_mdl_log_all_it_age_bmi = summary(mdl_log_all_it_age_bmi)
summary_mdl_log_all_it_age_bmi
## 
## Call:
## lm(formula = log.charges ~ sex + age * bmi + children + smoker + 
##     region, data = ds)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -1.16889 -0.18730 -0.04645  0.05017  2.11999 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)    
## (Intercept)      6.8100214  0.1899691  35.848  < 2e-16 ***
## sexmale         -0.0755310  0.0254979  -2.962 0.003114 ** 
## age              0.0407891  0.0046324   8.805  < 2e-16 ***
## bmi              0.0206788  0.0061133   3.383 0.000741 ***
## children         0.0977903  0.0104268   9.379  < 2e-16 ***
## smokerTRUE       1.5453779  0.0317178  48.723  < 2e-16 ***
## regionnorthwest -0.0383864  0.0366527  -1.047 0.295173    
## regionsoutheast -0.1469699  0.0364646  -4.030 5.92e-05 ***
## regionsouthwest -0.1213469  0.0361701  -3.355 0.000819 ***
## age:bmi         -0.0002011  0.0001475  -1.363 0.172992    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 0.4404 on 1194 degrees of freedom
## Multiple R-squared:  0.766,  Adjusted R-squared:  0.7642 
## F-statistic: 434.2 on 9 and 1194 DF,  p-value: < 2.2e-16
# View the model summary (dataframe format)
tidy_summary_mdl_log_all_it_age_bmi <- tidy(summary_mdl_log_all_it_age_bmi, conf.int = TRUE) %>% arrange(p.value)
tidy_summary_mdl_log_all_it_age_bmi

Model diagnosis

# Plot the residuals
plot(mdl_log_all_it_age_bmi)

Linear Regression: Predict charges adding interactions of smoker with bmi and with age (mdl_all_it_smoker_age_bmi)

mdl_all_it_smoker_age_bmi = lm(formula = charges ~ sex + smoker * bmi + children + smoker * age + region, data = ds)
mdl_all_it_smoker_age_bmi
## 
## Call:
## lm(formula = charges ~ sex + smoker * bmi + children + smoker * 
##     age + region, data = ds)
## 
## Coefficients:
##     (Intercept)          sexmale       smokerTRUE              bmi  
##       -2188.841         -510.983       -19916.279           21.288  
##        children              age  regionnorthwest  regionsoutheast  
##         471.292          265.821         -432.415        -1156.795  
## regionsouthwest   smokerTRUE:bmi   smokerTRUE:age  
##       -1234.228         1429.975           -2.294
# View the model summary (plain text format)
summary_mdl_all_it_smoker_age_bmi = summary(mdl_all_it_smoker_age_bmi)
summary_mdl_all_it_smoker_age_bmi
## 
## Call:
## lm(formula = charges ~ sex + smoker * bmi + children + smoker * 
##     age + region, data = ds)
## 
## Residuals:
##      Min       1Q   Median       3Q      Max 
## -14302.0  -1905.3  -1371.2   -497.7  30446.7 
## 
## Coefficients:
##                   Estimate Std. Error t value Pr(>|t|)    
## (Intercept)      -2188.841    931.119  -2.351  0.01890 *  
## sexmale           -510.983    284.645  -1.795  0.07288 .  
## smokerTRUE      -19916.279   1924.375 -10.349  < 2e-16 ***
## bmi                 21.288     27.085   0.786  0.43205    
## children           471.292    116.331   4.051 5.42e-05 ***
## age                265.821     11.406  23.306  < 2e-16 ***
## regionnorthwest   -432.415    409.011  -1.057  0.29063    
## regionsoutheast  -1156.795    406.787  -2.844  0.00453 ** 
## regionsouthwest  -1234.228    403.641  -3.058  0.00228 ** 
## smokerTRUE:bmi    1429.975     55.999  25.536  < 2e-16 ***
## smokerTRUE:age      -2.294     25.658  -0.089  0.92877    
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## Residual standard error: 4912 on 1193 degrees of freedom
## Multiple R-squared:  0.836,  Adjusted R-squared:  0.8346 
## F-statistic: 608.1 on 10 and 1193 DF,  p-value: < 2.2e-16
# View the model summary (dataframe format)
tidy_summary_mdl_all_it_smoker_age_bmi <- tidy(summary_mdl_all_it_smoker_age_bmi, conf.int = TRUE) %>% arrange(p.value)
tidy_summary_mdl_all_it_smoker_age_bmi

Model diagnosis

# Plot the residuals
plot(mdl_all_it_smoker_age_bmi)

Model Comparison

Evaluation in training data

#' Compute the metrics of a fitted model on the original `charges` scale.
#'
#' Predictions are obtained with `broom::augment()` and scored against the
#' `charges` column with `metrics()`. Predictions of models whose response is
#' not `charges` are back-transformed with `exp()` first, so that all models
#' can be compared on the same scale.
#'
#' @param model A fitted model whose response is either `charges` or the log
#'   of charges (`log.charges`).
#' @param test_data Optional data frame to evaluate on; it must contain a
#'   `charges` column. When `NULL`, the model is evaluated on `train_data`.
#' @param train_data Data frame used when `test_data` is `NULL`. Defaults to
#'   the global training set `ds`.
#' @return The tibble returned by `metrics()`, with `.estimate` rounded to 4
#'   decimals and an extra `model` column holding the model formula as text.
get_metrics_from_model <- function(model, test_data = NULL, train_data = ds) {
  model_terms <- model[["terms"]]
  target <- deparse(model_terms[[2]])
  # deparse() splits long expressions into several strings; collapse them so
  # the label is always a single string and cannot be recycled across rows.
  predictors <- paste(trimws(deparse(model_terms[[3]])), collapse = " ")
  mdl_name <- paste0(target, "~", predictors)

  if (is.null(test_data)) {
    cat("No test data\n")
    eval_data <- broom::augment(model, train_data)
  } else {
    eval_data <- broom::augment(model, newdata = test_data)
  }

  # Any response other than raw charges is treated as log(charges): undo the
  # log so RMSE, R2 and MAE are always computed for charges.
  if (target == "charges") {
    eval_data$.prediction <- eval_data$.fitted
  } else {
    eval_data$.prediction <- exp(eval_data$.fitted)
  }

  metrics(data = eval_data, truth = charges, estimate = .prediction) %>%
    mutate(.estimate = round(.estimate, 4), model = mdl_name)
}

# All the models to compare, keyed by the name of the variable holding them
models <- list(
  mdl_all = mdl_all,
  mdl_log_all = mdl_log_all,
  mdl_log_all_it_smoker_age = mdl_log_all_it_smoker_age,
  mdl_log_all_it_smoker_age_bmi = mdl_log_all_it_smoker_age_bmi,
  mdl_log_all_it_age_bmi = mdl_log_all_it_age_bmi,
  mdl_all_it_smoker_age_bmi = mdl_all_it_smoker_age_bmi
)

# Training-set metrics of every model: one tibble per model, stacked in a
# single bind_rows() call instead of growing the result inside a loop.
metrics_total <- models %>%
  lapply(get_metrics_from_model) %>%
  bind_rows()
## No text dataNo text dataNo text dataNo text dataNo text dataNo text data
# Compare RMSE (ascending: smallest error first)
metrics_total  %>%
  filter(.metric == "rmse") %>%
  arrange(.estimate)
# Compare MAE (ascending: smallest error first)
metrics_total  %>%
  filter(.metric == "mae") %>%
  arrange(.estimate)
# Compare R-squared (descending: largest value first)
metrics_total  %>%
  filter(.metric == "rsq") %>%
  arrange(desc(.estimate))

Evaluation in test data

We compute the MAE, RMSE and R² of the 6 models on the test data.

# Test-set metrics of every model: one tibble per model, stacked in a single
# bind_rows() call instead of growing the result inside a loop.
metrics_total <- models %>%
  lapply(get_metrics_from_model, test_data = ds_test) %>%
  bind_rows()
# Compare RMSE (ascending: smallest error first)
metrics_total  %>%
  filter(.metric == "rmse") %>%
  arrange(.estimate)
# Compare MAE (ascending: smallest error first)
metrics_total  %>%
  filter(.metric == "mae") %>%
  arrange(.estimate)
# Compare R-squared (descending: largest value first)
metrics_total  %>%
  filter(.metric == "rsq") %>%
  arrange(desc(.estimate))